library(rpart)
library(rattle)
## Warning: package 'rattle' was built under R version 4.2.2
## Loading required package: tibble
## Loading required package: bitops
## Rattle: A free graphical interface for data science with R.
## Versión 5.5.1 Copyright (c) 2006-2021 Togaware Pty Ltd.
## Escriba 'rattle()' para agitar, sacudir y rotar sus datos.
library(tidyverse)
## ── Attaching packages ─────────────────────────────────────── tidyverse 1.3.1 ──
## ✔ ggplot2 3.3.6 ✔ dplyr 1.0.9
## ✔ tidyr 1.2.0 ✔ stringr 1.4.0
## ✔ readr 2.1.2 ✔ forcats 0.5.1
## ✔ purrr 0.3.4
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag() masks stats::lag()
library(plotly)
##
## Attaching package: 'plotly'
## The following object is masked from 'package:ggplot2':
##
## last_plot
## The following object is masked from 'package:stats':
##
## filter
## The following object is masked from 'package:graphics':
##
## layout
# Create a synthetic dataset (plotted in 3D further below).
# Both predictors are drawn uniformly from [4.5, 13.5]; the response is the
# Euclidean distance from (x1, x2) to the point (9, 9) plus Gaussian noise
# with standard deviation 0.5.
set.seed(911) # fixed seed so the simulated data are reproducible
n <- 1000
dtrain <- data.frame(x1 = runif(n, 4.5, 13.5), x2 = runif(n, 4.5, 13.5))
noise <- rnorm(n, mean = 0, sd = 0.5)
dtrain$y <- with(dtrain, sqrt((x1 - 9)^2 + (x2 - 9)^2)) + noise
# Visualise the synthetic dataset: first the points in (x1, x2, y), then the
# same points using only the two predictors (x1, x2).
dtrain %>%
  plot_ly(x = ~x1, y = ~x2, z = ~y) %>%
  add_markers(size = 1, color = I("orange"))

dtrain %>%
  plot_ly(x = ~x1, y = ~x2) %>%
  add_markers(size = 1, color = I("orange"))
$Y = f(x) + \varepsilon = \sqrt{(x_1 - 9)^2 + (x_2 - 9)^2} + \varepsilon$
# Fit a regression tree ("anova" method) of y on x1 and x2. The control
# settings maxdepth = 3, minsplit = 1, minbucket = 1 and cp = 0 are passed
# directly in the rpart() call.
tree <- rpart(
  y ~ x1 + x2,
  data = dtrain,
  method = "anova",
  maxdepth = 3,
  minsplit = 1,
  minbucket = 1,
  cp = 0
)
fancyRpartPlot(tree)

# Predictions of the fitted tree for the training observations.
fitted.values <- predict(tree, newdata = dtrain)

# Vector with the node numbers of the terminal nodes: the row names of the
# tree's frame whose `var` entry is "<leaf>".
frame <- tree$frame
nodevec <- as.numeric(rownames(frame)[frame$var == "<leaf>"])

# List in which each element gives the path (sequence of splits starting at
# "root") to one terminal node; path.rpart() also prints those paths.
path.list <- path.rpart(tree, nodes = nodevec)
##
## node number: 8
## root
## x2>=6.26
## x2< 11.41
## x1< 11.73
##
## node number: 9
## root
## x2>=6.26
## x2< 11.41
## x1>=11.73
##
## node number: 10
## root
## x2>=6.26
## x2>=11.41
## x2< 12.49
##
## node number: 11
## root
## x2>=6.26
## x2>=11.41
## x2>=12.49
##
## node number: 12
## root
## x2< 6.26
## x1>=6.345
## x1< 11.73
##
## node number: 13
## root
## x2< 6.26
## x1>=6.345
## x1>=11.73
##
## node number: 14
## root
## x2< 6.26
## x1< 6.345
## x2>=5.699
##
## node number: 15
## root
## x2< 6.26
## x1< 6.345
## x2< 5.699
# Turn every root-to-leaf path into the rectangle of the (x1, x2) plane that
# the corresponding terminal node covers. `rect_info` ends up with one row per
# path and the columns xmin, xmax (bounds on x1) and ymin, ymax (bounds on x2).

# Bounding box of the training data. It does not depend on the path, so it is
# computed once instead of once per path.
x1_range <- range(dtrain$x1)
x2_range <- range(dtrain$x2)

# Narrow the data's bounding box with each split rule found on `path`.
#   path: character vector such as c("root", "x2>=6.26", "x1< 11.73").
# Returns a one-row data frame with columns xmin, xmax, ymin, ymax.
path_to_rect <- function(path) {
  lower <- c(x1 = x1_range[1], x2 = x2_range[1])
  upper <- c(x1 = x1_range[2], x2 = x2_range[2])

  for (rule in setdiff(path, "root")) {
    # A rule reads "<variable>< <cutoff>" or "<variable>>=<cutoff>".
    parts <- unlist(str_split(rule, "< |>="))
    # The tree is fitted on x1 and x2 only, so anything that is not x1 is
    # handled as x2.
    axis <- if (parts[1] == "x1") "x1" else "x2"
    cutoff <- as.numeric(parts[2])

    if (str_detect(rule, "< ")) {
      # "variable < cutoff" caps the upper bound along that axis.
      upper[[axis]] <- cutoff
    } else {
      # "variable >= cutoff" raises the lower bound along that axis.
      lower[[axis]] <- cutoff
    }
  }

  data.frame(
    xmin = lower[["x1"]], xmax = upper[["x1"]],
    ymin = lower[["x2"]], ymax = upper[["x2"]]
  )
}

# Build all rectangles first and bind them once (no rbind() inside a loop).
# unname() keeps the default 1..n row names; an empty path list yields NULL.
rect_info <- do.call(rbind, lapply(unname(path.list), path_to_rect))
# rect_info <- rect_info %>% mutate(xmed = xmin+((xmax - xmin)/2),
# ymed = ymin + ((ymax - ymin)/2),
# val = levels(as.factor(round(fitted.values,2))))
#
# ggplot() +
# geom_rect(data = rect_info,aes(xmin = xmin, xmax = xmax, ymin = ymin, ymax = ymax),colour = "grey50", fill = "white") +
# geom_point(data = dtrain,aes(x = x1, y = x2, color = round(fitted.values, 2))) +
# labs(color="Valor ajustado")+
# geom_label(rect_info, aes(x=xmed, y=ymed, label=val))+
# theme_light()
# Attach the fitted value of each observation, rounded to 2 decimals, as a
# factor so it can be used as a discrete colour and as a grouping key.
# The column is assigned with `$<-` instead of inside mutate(): within
# mutate() the name `fitted.values` would resolve to this new column once it
# exists, so re-running the chunk would call round() on a factor and fail.
# Here the right-hand side always refers to the numeric prediction vector.
dtrain$fitted.values <- as.factor(round(fitted.values, 2))

# One label position per fitted value: the median x1 and median x2 of the
# observations that share that value.
label_points <- dtrain %>%
  group_by(fitted.values) %>%
  summarise(x = median(x1), y = median(x2))

# Draw the rectangles stored in rect_info, the observations coloured by
# fitted value, and one label per fitted value.
ggplot() +
  geom_rect(
    data = rect_info,
    aes(xmin = xmin, xmax = xmax, ymin = ymin, ymax = ymax),
    colour = "grey50", fill = "white"
  ) +
  geom_point(data = dtrain, aes(x = x1, y = x2, color = fitted.values)) +
  geom_label(data = label_points, aes(x = x, y = y, label = fitted.values)) +
  labs(color = "Valor ajustado") +
  theme_light()